{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "N1BarZA4nZGb"
   },
   "source": [
    "load data and make dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "gpus = tf.config.experimental.list_physical_devices('GPU')\n",
    "train_dir = '../dataset/classification/train'\n",
    "valid_dir = '../dataset/classification/valid'\n",
    "batch_size = 64\n",
    "crop_size = 224\n",
    "\n",
    "train_dataset = tf.keras.utils.image_dataset_from_directory(\n",
    "    train_dir,\n",
    "    seed=1,\n",
    "    image_size=(crop_size, crop_size),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")\n",
    "\n",
    "valid_dataset = tf.keras.utils.image_dataset_from_directory(\n",
    "  valid_dir,\n",
    "  seed=1,\n",
    "  image_size=(crop_size, crop_size),\n",
    "  batch_size=batch_size,\n",
    ")\n",
    "\n",
    "def _normalize_img(img, label):\n",
    "    img = tf.keras.applications.resnet_v2.preprocess_input(img)\n",
    "    return (img, label)\n",
    "\n",
    "train_dataset = train_dataset.map(_normalize_img)\n",
    "valid_dataset = valid_dataset.map(_normalize_img)\n",
    "\n",
    "cnt_1,cnt_0 = 0,0\n",
    "for batch in train_dataset:\n",
    "    cnt_1 = cnt_1 + np.sum(batch[1]==1)\n",
    "    cnt_0 = cnt_0 + np.sum(batch[1]==0)\n",
    "print('train:  positive ',cnt_1,' negative ',cnt_0)\n",
    "\n",
    "cnt_1,cnt_0 = 0,0\n",
    "for batch in valid_dataset:\n",
    "    cnt_1 = cnt_1 + np.sum(batch[1]==1)\n",
    "    cnt_0 = cnt_0 + np.sum(batch[1]==0)\n",
    "print('valid:  positive ',cnt_1,' negative ',cnt_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dataset(dataset):\n",
    "    X=tf.zeros([0,2048])\n",
    "    Y=tf.zeros([0],dtype=tf.int32)\n",
    "    for sample in dataset:\n",
    "        x = model.predict(sample[0])\n",
    "        y = sample[1]\n",
    "        X = tf.concat([X,x],0)\n",
    "        Y = tf.concat([Y,y],0)\n",
    "    return X,Y\n",
    "\n",
    "# KNN\n",
    "\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "\n",
    "def measure(model):\n",
    "    \n",
    "    X_train,Y_train = get_dataset(train_dataset)\n",
    "    X_valid,Y_valid = get_dataset(valid_dataset)\n",
    "\n",
    "    valid_K, valid_acc = 0,0\n",
    "\n",
    "    K=1\n",
    "        \n",
    "    correct = 0\n",
    "    sum = 0\n",
    "    neigh = KNeighborsClassifier(n_neighbors=K)\n",
    "    neigh.fit(X_train, Y_train)\n",
    "\n",
    "    acc = neigh.score(X_valid,Y_valid)\n",
    "    \n",
    "    valid_acc = acc\n",
    "    valid_K = K\n",
    "            \n",
    "    return valid_K, valid_acc\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow.keras as keras\n",
    "from sklearn.cluster import KMeans\n",
    "import tensorflow_hub as hub\n",
    "\n",
    "origin_resnet = keras.applications.ResNet50V2(weights='imagenet')\n",
    "resnet = tf.keras.Model(inputs=origin_resnet.input,outputs=origin_resnet.layers[-2].output)\n",
    "\n",
    "model = tf.keras.Sequential([\n",
    "    resnet,\n",
    "    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings\n",
    "])\n",
    "\n",
    "optimizer = tf.keras.optimizers.Adam(1e-4)\n",
    "checkpoint_filepath = '/tmp/cluster'\n",
    "\n",
    "def loss_fn(x, y, cluster_labels, training=True):\n",
    "    \n",
    "    embeddings = model(x, training=training)\n",
    "    \n",
    "    kmeans = KMeans(n_clusters=4).fit(embeddings) \n",
    "    cluster_labels = tf.constant(kmeans.labels_)  \n",
    "\n",
    "    class_labels = tf.constant(y)\n",
    "    \n",
    "    batch = len(cluster_labels)\n",
    "\n",
    "    L2 =  lambda x,y : tf.sqrt(tf.reduce_sum(tf.square(x-y),axis=-1))\n",
    "\n",
    "    triplets=[]\n",
    "    \n",
    "    for i in range(batch):\n",
    "        anchor = embeddings[i]\n",
    "        \n",
    "        cluster_same_id = tf.where(cluster_labels==cluster_labels[i])\n",
    "        class_same_id = tf.where(class_labels==class_labels[i])\n",
    "        candidate_id = tf.sets.intersection(tf.transpose(cluster_same_id,[1,0]), tf.transpose(class_same_id,[1,0])).values\n",
    "        if(len(candidate_id)<=1):\n",
    "            continue\n",
    "        candidate_positives = tf.gather_nd(embeddings,tf.expand_dims(candidate_id,-1))\n",
    "        dis = L2(anchor,candidate_positives)\n",
    "        positive = candidate_positives[tf.argmax(dis)]\n",
    "        \n",
    "        \n",
    "        class_diff_id = tf.where(class_labels!=class_labels[i])\n",
    "        if(len(class_diff_id)==0):\n",
    "            continue\n",
    "        candidate_negatives = tf.gather_nd(embeddings,class_diff_id)\n",
    "        dis = L2(anchor,candidate_negatives)\n",
    "        negative = candidate_negatives[tf.argmin(dis)]\n",
    "        \n",
    "        \n",
    "        ap = L2(anchor, positive)\n",
    "        an = L2(anchor, negative)\n",
    "\n",
    "        triplet = tf.maximum(ap - an + 0.01, 0.)\n",
    "        \n",
    "        \n",
    "        triplets.append(triplet)\n",
    "    \n",
    "    return tf.reduce_mean(triplets), cluster_labels\n",
    "\n",
    "def train_step(x, y,cluster_labels):\n",
    "    \n",
    "    with tf.GradientTape() as tape:\n",
    "        loss, cluster_labels = loss_fn(x, y, cluster_labels,  training=True)\n",
    "    \n",
    "    grads = tape.gradient(loss, model.trainable_weights)\n",
    "    optimizer.apply_gradients(zip(grads, model.trainable_weights))\n",
    "    \n",
    "\n",
    "    return loss, cluster_labels\n",
    "\n",
    "def test_step(x, y, cluster_labels):\n",
    "    loss, cluster_labels = loss_fn(x, y, cluster_labels,  training=False)\n",
    "    return loss,  cluster_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 350\n",
    "best_acc = 0.\n",
    "for epoch in range(epochs):\n",
    "    print('---------------------------------------------------------------------------------')\n",
    "    # Iterate over the batches of the dataset.\n",
    "    \n",
    "    train_cluster_labels, test_cluster_labels = None, None\n",
    "    \n",
    "    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n",
    "        \n",
    "        train_loss, train_cluster_labels = train_step(x_batch_train, y_batch_train, train_cluster_labels)\n",
    "\n",
    "    # Run a validation loop at the end of each epoch.                                                                                                                       \n",
    "    for step, (x_batch_test, y_batch_test) in enumerate(test_dataset):\n",
    "        test_loss, test_cluster_labels = test_step(x_batch_test, y_batch_test, test_cluster_labels)\n",
    "    \n",
    "    \n",
    "    K, acc = measure(model)\n",
    "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                    \n",
    "    if acc>=best_acc:\n",
    "        model.save_weights(checkpoint_filepath)\n",
    "        best_acc = acc\n",
    "     \n",
    "    print(\"epoch: %d\" % (epoch),\"  train loss: %.4f\" % (float(train_loss)), \n",
    "          \"  test loss: %.4f\" % (float(test_loss)),\"  test acc: %.4f\" % (float(acc)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.load_weights(checkpoint_filepath)\n",
    "K, acc = measure(model)\n",
    "\n",
    "print('acc: {}'.format(acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyPu84wYnnCSF9QmveCX/MYD",
   "mount_file_id": "1arPZJvLbz2P-NGFUev-JQExkQ4038eot",
   "name": "cam.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
